import multiprocessing as mp
import logging
import os
from os import makedirs
from os.path import join, abspath, dirname, pardir
import numpy as np
import subprocess
import argparse
from scapy.all import *
import glob
#from common import My_Source_Ips

# IP addresses treated as the local (client) side. getDirection() labels a
# packet +1 when its IP source is in this set and -1 otherwise.
My_Source_Ips = {'192.168.241.170', '192.168.243.77'}


# Size in bytes of one Tor cell: 512 in older Tor versions, 514 in newer ones.
CELL_SIZE = 514
# NOTE(review): the name suggests cell + TLS header + extra framing, but the
# value is currently identical to CELL_SIZE.
MY_CELL_SIZE = CELL_SIZE
# Label constants, presumably for dummy (padding) vs. real cells; they are
# not referenced elsewhere in this script.
isDummy, isReal = 888, 1

# File-name ending of the captures to parse
# (previously used alternative: '.pcap.filtered').
captured_file_name = '.pcap'
# Output root: <parent of this script's directory>/traffic_sequence/parsed
ParsedDir = join(abspath(join(dirname(__file__), pardir)), "traffic_sequence/parsed")

def init_directories(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` so creation is a single call. A separate
    "exists?" check followed by ``makedirs`` can fail with
    ``FileExistsError`` if the directory appears in between.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
    """
    makedirs(path, exist_ok=True)


def getTimestamp(pkt, t0):
    """Return ``pkt.time`` relative to the reference time *t0*, as a float.

    The result is in the same units as ``pkt.time``.
    """
    elapsed = pkt.time - t0
    return float(elapsed)


def getDirection(pkt):
    """Return +1 if the packet's source address is in ``My_Source_Ips``, else -1.

    ``pkt.payload`` must expose a ``src`` field. This presumably assumes
    link-layer (e.g. Ethernet) framing, so that ``payload`` is the IP
    layer — TODO confirm for other capture link types.
    """
    outgoing = pkt.payload.src in My_Source_Ips
    return 1 if outgoing else -1

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``dir``, ``mode``, ``u``, ``s``,
        ``suffix`` and ``proc_num``.
    """
    cli = argparse.ArgumentParser(description='Parse captured traffic.')

    # (flag, keyword options) pairs, registered in this exact order.
    argument_specs = [
        ('dir', dict(type=str,
                     metavar='<dataset path>',
                     help='Path of dataset.')),
        ('-mode', dict(type=str,
                       metavar='<parse mode>',
                       help='The type of dataset: clean, burst?.')),
        ('-u', dict(action='store_true',
                    default=False,
                    help='is monitored webpage or unmonitored? (default:is monitored, false)')),
        ('-s', dict(action='store_true',
                    default=False,
                    help='If use screenshot as sanity check?')),
        ('-suffix', dict(type=str,
                         metavar='<parsed file suffix>',
                         default='.cell',
                         help='to save file as xx.suffix')),
        ('-proc_num', dict(type=int,
                           metavar='<process num>',
                           default=2,
                           help='The num of CPU')),
    ]
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# TOR CELL EXTRACTION LOGIC
def clean_parse(fdir):
    """Convert one pcap capture into a per-cell time/direction sequence file.

    For every packet, the bytes above the transport layer are searched for
    a TLS application-data record header: ``17 03 03`` followed by a 2-byte
    big-endian length. Starting at the first such header, consecutive
    records are walked by their length. Each record counts as
    ``round(length / CELL_SIZE)`` cells, and one line
    ``<time since first packet>  TAB  <direction>`` is written per cell to
    ``<savedir>/<capture name><suffix>``.

    A record that runs past the end of its packet is not reassembled:
    scanning of that packet stops there.

    Reads the module globals ``savedir``, ``suffix`` and ``isunmon``, which
    must be set before this runs.

    Args:
        fdir: path of the ``.pcap`` file to parse.

    Captures with fewer than 50 packets are skipped. A read or parse error
    is reported on stdout and the capture is skipped, so one bad file does
    not abort a whole ``Pool.map`` run.
    """
    name = os.path.basename(fdir)
    site = name.split(".pcap")[0]
    savefiledir = join(savedir, site + suffix)
    if not isunmon:
        print(savefiledir)

    print("Reading PCAP: ")
    try:
        packets = rdpcap(fdir)
    except Exception as e:
        # Keep the worker alive: an unreadable capture only skips this file.
        print("Error in {}, {} ".format(name, e))
        return
    if len(packets) < 50:
        print("[WARN] {} has too few packets, skip!".format(fdir))
        return

    tls_header = b'\x17\x03\x03'
    try:
        with open(savefiledir, 'w') as f:
            # Timestamps are relative to the first packet of the capture.
            t0 = packets[0].time
            for pkt in packets:
                b = raw(pkt.payload.payload.payload)
                byte_ind = b.find(tls_header)
                # A full 5-byte record header (type, version, length) is
                # required; a header cut off by the segment end has no
                # usable length.
                while byte_ind != -1 and byte_ind + 5 <= len(b):
                    if b[byte_ind:byte_ind + 3] != tls_header:
                        break
                    tls_len = int.from_bytes(b[byte_ind + 3:byte_ind + 5], 'big')
                    cell_num = int(np.round(tls_len / CELL_SIZE))
                    if cell_num > 0:
                        line = "{:.6f}\t{:d}\n".format(getTimestamp(pkt, t0),
                                                       getDirection(pkt))
                        f.write(line * cell_num)
                    # Jump to where the next record header should start.
                    byte_ind += tls_len + 5
    except Exception as e:
        print("Error in {}, {} ".format(name, e))


if __name__ == "__main__":
    global savedir, suffix, isunmon
    args = parse_arguments()
    suffix = args.suffix
    isunmon = args.u
    # filelist = glob.glob(join(args.dir,'*_*_*' ,'capture.pcap.filtered'))
    filename = args.dir.rstrip("/").split("/")[-1]
    savedir = join(ParsedDir, filename)
    init_directories(savedir)
    print("Parsed file in {}".format(savedir))
    if args.s:
        filelist_ = glob.glob(join(args.dir,'*.png'))
        filelist = []
        #Sanity check
        for f in filelist_:
            pcapfile = f.split(".png")[0] + captured_file_name
            if os.path.exists(pcapfile):
                filelist.append(pcapfile)
    else:
        filelist =  glob.glob(join(args.dir,'*'+captured_file_name))
    print(filelist)

    # for f in filelist:
    #   parse(f)
    print("Total:{}".format(len(filelist)))
    pool = mp.Pool(processes=args.proc_num)
    if args.mode == 'clean':
        # pool.map(clean_parse, filelist)
        pool.map(clean_parse, filelist)
    else:
        raise ValueError('Wrong mode:{}'.format(args.mode))

    # zipcmd = "zip -rq " + savedir.rstrip("/") + ".zip" + " " + savedir
    # print(zipcmd)
    # subprocess.call(zipcmd, shell=True)

